"""
采样部分部分标签数据，检验transformation的性质。探究model bias，可以视为一种 model selection。
1-5: shape,scaling,orientation,x,y
"""
import argparse
from collections import defaultdict

import numpy as np
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange
from fastai.vision import *
from disvae import Evaluator
from disvae.models.anneal import *
from disvae.models.decoders import get_decoder
from disvae.models.encoders import get_encoder
from disvae.models.losses import *
from utils.datasets import get_dataloaders, DSprites
from utils.helpers import set_seed
from utils.visualize import *

hyperparameter_defaults = {
    'batch_size': 512,
    'learning_rate': 0.0005,
    'epochs': 10,
    'transformation': 1,  # which dSprites factor is (partially) supervised
    'loss': 'betaH',
    'random_seed': 156,
    'beta': 4,
}

# Expose every default as a CLI flag of the same name; the flag's type is
# inferred from the default value, so numeric flags stay numeric.
parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    parser.add_argument(f'--{name}', type=type(default), default=default)
args = parser.parse_args()
config = args  # alias: the rest of the script reads hyperparameters via `config`
set_seed(config.random_seed)


def get_preds(model, dl):
    """Run the encoder over every batch of `dl` and collect the results.

    Returns a tuple `(mu, images, labels)` of concatenated tensors: the
    encoder means, the (reshaped, GPU-resident) input images, and the raw
    labels. The model is switched to eval mode for the pass and back to
    train mode before returning.
    """
    model.eval()
    mus, imgs, labels = [], [], []
    with torch.no_grad():
        for batch, y in dl:
            batch = batch.view(-1, 1, 64, 64).cuda()
            mu, _ = model.encoder(batch)
            mus.append(mu)
            imgs.append(batch)
            labels.append(y)
    model.train()
    return torch.cat(mus), torch.cat(imgs), torch.cat(labels)


def evaluate(model, dl, storer):
    """Scatter-plot the three highest-KL latent dimensions against the first
    three ground-truth factors and stash the figure in `storer['correlated']`.
    """
    codes, _, labels = get_preds(model, dl)

    # Rank latent dimensions by their logged per-dimension KL (largest first).
    kls = torch.Tensor([val for key, val in storer.items() if 'kl_' in key])
    _, order = kls.sort(descending=True)

    if len(order) >= 3:
        zs = codes[:1000, order[:3]].cpu()
        factors = labels[:1000]
        fig, axes = plt.subplots(3, 3)
        for row in range(3):  # ground-truth factor
            for col in range(3):  # latent variable
                axes[row, col].scatter(zs[:, col].numpy(), factors[:, row].numpy(), s=0.2)
                axes[row, col].set_title(f'{row},{order[col]}')
        storer['correlated'] = wandb.Image(fig)

    plt.close()


def inversion(seq):
    """Count ordered vs. disordered position pairs per row of `seq`.

    For every pair of columns left < right, counts how often
    seq[:, left] > seq[:, right] (an inversion) and how often
    seq[:, left] <= seq[:, right]. Returns the element-wise (min, max) of
    the two counts, each of shape (batch,).
    """
    inversions = 0
    in_order = 0
    for right in range(1, seq.size(1)):
        for left in range(right):
            greater = (seq[:, left] > seq[:, right]).long()
            inversions = inversions + greater
            in_order = in_order + (1 - greater)  # complement of `greater`
    counts = torch.stack([inversions, in_order])
    return counts.min(0)[0], counts.max(0)[0]


def plt_sample_traversal(sample, label, model, traversal_len=5, dim_list=range(4), r=3):
    """Plot a latent traversal grid: one row per latent dim in `dim_list`,
    one column per traversal step in [-r, r].

    Parameters
    ----------
    sample : torch.Tensor or None
        A batch of one image to encode for the base latent; if None, the
        origin (zeros) is used instead.
    label : torch.Tensor
        Label vector concatenated to the latent before decoding.
    model : VAE
        Provides `encoder` and `decoder`.
    traversal_len : int
        Number of evenly spaced values traversed per dimension.
    dim_list : iterable of int
        Latent dimensions to traverse.
    r : float
        Traversal range is [-r, r].

    Returns
    -------
    matplotlib figure with the decoded traversals.
    """
    dim_len = len(dim_list)
    with torch.no_grad():
        if sample is not None:
            mu, _ = model.encoder(sample)
        else:
            # NOTE(review): created on CPU — assumes the decoder input is
            # moved/compatible downstream; confirm if the None path is used.
            mu = torch.zeros(1, dim_len)
        mu = torch.cat([mu, label], 1)
        fig, axes = plt.subplots(dim_len, traversal_len,
                                 figsize=(traversal_len * 4, dim_len * 4))
        # BUG FIX: the axes grid was reshaped with a hard-coded width of 7,
        # which crashed for any traversal_len != 7 (including the default 5).
        # Reshape also normalizes the squeezed 1-D array when dim_len == 1.
        axes = axes.reshape(dim_len, traversal_len)
        plt.tight_layout()

        for i, dim in enumerate(dim_list):
            base_latents = mu
            linear_traversal = torch.linspace(-r, r, traversal_len)
            traversals = traversal_latents(base_latents, linear_traversal, dim)
            recon_batch = model.decoder(traversals)

            plot_bar(axes[i], recon_batch[:, 0])

        return fig

def train(dl, model, epochs, loss_f, transformation):
    """Train `model` on `dl` with the chosen loss, logging to wandb.

    The label column `transformation` is zeroed before conditioning the
    decoder, so the model must encode that factor in its latent. After each
    epoch, logs a traversal figure and order-agreement statistics between
    the latent ordering and the ground-truth factor ordering.

    Returns the final `storer` dict with any remaining list metrics averaged.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    storer = defaultdict(list)
    itr = 0
    for e in trange(epochs):
        for i, (img, label) in enumerate(dl):
            img = img.view(-1, 1, 64, 64).cuda().float()
            # Hide the supervised factor from the decoder: zero its column
            # on a clone so the underlying dataset labels are untouched.
            label = label.float().clone().cuda()
            label[:, transformation] = 0
            recon_batch, latent_dist, latent_sample = model(img, label)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())
            itr += 1
            if itr % 100 == 0:
                # Average accumulated per-iteration metrics before logging.
                for k, v in storer.items():
                    if isinstance(v, list):
                        storer[k] = np.mean(v)
                storer['itr'] = itr
                # evaluate(model, train_loader, storer)
                wandb.log(storer, sync=False)
                storer = defaultdict(list)

        model.eval()
        # BUG FIX: previously traversed the global `vae` instead of the
        # `model` argument; use the model actually being trained.
        fig = plt_sample_traversal(img[:1], label[:1], model, 7, [0])
        storer['traversal'] = wandb.Image(fig)

        # Reshape posterior means onto the dSprites factor grid
        # (shape, scale, orientation, x, y) — presumably [3, 6, 40, 32, 32];
        # verify against the dataset ordering.
        params_zCX, labels = evaluator.compute(test_loader)
        mu = params_zCX[0].view([3, 6, 40, 32, 32, 1]).cpu()
        labels = labels.view([3, 6, 40, 32, 32, 6])
        # Sort along the supervised factor's axis and compare the latent
        # ordering with the ground-truth ordering of that factor.
        order = mu.transpose(transformation - 1, 5).argsort(5)
        gt_order = labels[..., transformation:transformation + 1].transpose(transformation - 1, 5) - 1
        diff = F.mse_loss(gt_order.float(), order.float()).item()

        elements = order.size(5)
        order = order.reshape(-1, elements)
        gt_order = gt_order.reshape(-1, elements)
        x = torch.arange(len(order)).reshape(-1, 1).expand(len(order), elements)
        # Ground-truth ranks permuted by the predicted order; count inversions.
        seq = gt_order[x, order]
        s_min, s_max = inversion(seq)

        wandb.log({'diff': diff,
                   's_min': s_min.float().mean().item(),
                   's_max': s_max.float().mean().item(),
                   'inv': (s_min.float() / s_max.float()).mean().item()})

        model.train()

    # Average any metrics accumulated since the last wandb.log flush.
    for k, v in storer.items():
        if isinstance(v, list):
            storer[k] = np.mean(v)
    return storer


beta = args.beta
transformation = config.transformation
# `transformation` selects the dSprites latent factor to supervise
# (1-5: shape, scaling, orientation, x, y — see the module docstring).
# Use a real check instead of `assert`, which is stripped under `python -O`.
if not 1 <= transformation <= 5:
    raise ValueError(f'transformation must be in [1, 5], got {transformation}')
epochs = config.epochs

dsprites = DSprites()


class VAE(nn.Module):
    """Conditional VAE: a 1-d latent code is concatenated with the label
    vector and fed to the decoder.

    Parameters
    ----------
    img_size : tuple of ints
        Size of images. E.g. (1, 32, 32) or (3, 64, 64).
    encoder, decoder : callables
        Factories mapping `(img_size, latent_dim)` to the actual modules;
        the encoder is built with 1 latent, the decoder with 7 inputs
        (1 latent + 6 label dimensions).
    """

    def __init__(self, img_size, encoder, decoder, ):
        super().__init__()

        if list(img_size[1:]) not in ([32, 32], [64, 64]):
            raise RuntimeError(
                f"{img_size} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!")

        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, 1)
        self.decoder = decoder(img_size, 7)
        self.latent_dim = 1

    def reparameterize(self, mean, logvar):
        """Sample from N(mean, exp(logvar)) with the reparameterization
        trick during training; return the mean deterministically in eval
        mode (reconstruction).

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape
            (batch_size, latent_dim)
        """
        if not self.training:
            return mean
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mean + noise * std

    def forward(self, x, label):
        """Encode `x`, sample a latent, concatenate `label`, and decode.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        label : torch.Tensor
            Conditioning vector concatenated to the latent sample.
        """
        mean_logvar = self.encoder(x)
        z = self.reparameterize(*mean_logvar)
        decoder_input = torch.cat([z, label], dim=1)
        return self.decoder(decoder_input), mean_logvar, z


# NOTE(review): zeroing the factor directly in the dataset looks like an
# earlier approach; the label column is zeroed per-batch in train() instead.
# dsprites.lat_values[:, transformation] = 0
train_loader = DataLoader(dsprites, num_workers=4,
                          batch_size=config.batch_size,
                          shuffle=True)

# Group runs by this script's filename; the module docstring becomes the notes.
wandb.init(project="week_supervised", group=__file__.split('/')[-1][:-3], notes=__doc__, tags=['equilibrium'],
           config=args, )

# Single-latent conditional VAE (encoder: 1 latent; decoder: 1 latent + 6 labels).
image_size = (1, 64, 64)
vae = VAE(image_size, get_encoder('Burgess'), get_decoder('Burgess'))
vae.cuda()
vae.train()

# Constant beta schedule over all training epochs.
anneal = get_anneal('constant', epochs, 1, 1)
loss_f = get_loss_s(config.loss, anneal, beta, len(train_loader.dataset))

# Deterministic pass over the same dataset for evaluation.
test_loader = DataLoader(dsprites,
                         batch_size=config.batch_size,
                         num_workers=4,
                         shuffle=False, )

# NOTE(review): this second schedule overwrites `anneal` but is not passed
# anywhere afterwards — it appears to be dead; confirm before removing.
anneal = get_anneal('constant', 1, 1, 1)

evaluator = Evaluator(vae, loss_f,
                      device='cuda',
                      save_dir=wandb.run.dir,
                      is_progress_bar=False)

storer = train(train_loader, vae, epochs, loss_f, transformation)
# Checkpointing round-trip, currently disabled:
# torch.save(vae.state_dict(),'tmp.pt')
# vae.load_state_dict(torch.load('tmp.pt'))
vae.eval()
